﻿/**
* CRL
*/
using CRL.Core.Extension;
using CRL.Data.Attribute;
using CRL.Data.DBAccess;
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace CRL.Data.DBAdapter
{

	/// <summary>
	/// SQL Server (MSSQL) implementation of <see cref="DBAdapterBase"/>: DDL generation,
	/// dialect-specific SQL fragments, stored-procedure/paging templates and bulk write helpers.
	/// </summary>
	public partial class MSSQLDBAdapter : DBAdapterBase
	{
		/// <summary>
		/// Creates the adapter for the given database context.
		/// </summary>
		/// <param name="_dbContext">Context forwarded to <see cref="DBAdapterBase"/>.</param>
		public MSSQLDBAdapter(DbContextInner _dbContext)
			: base(_dbContext)
		{
		}

		/// <summary>
		/// SQL Server can compile stored procedures; always <c>true</c>.
		/// </summary>
		public override bool CanCompileSP
		{
			get
			{
				return true;
			}
		}

		#region Schema creation

		/// <summary>
		/// Builds a script that creates a stored procedure only when no procedure
		/// (sysobjects type 'P') with the given name exists yet.
		/// </summary>
		/// <param name="spName">Procedure name; embedded verbatim in a string literal.</param>
		/// <param name="script">The CREATE PROCEDURE text, executed through sp_executesql.</param>
		/// <returns>The guarded creation script.</returns>
		public override string GetCreateSpScript(string spName, string script)
		{
			// NOTE(review): neither spName nor script is quote-escaped here; a single quote inside
			// the procedure body would terminate the N'...' literal. Confirm callers pre-escape.
			string template = string.Format(@"
if not exists(select * from sysobjects where name='{0}' and type='P')
begin
exec sp_executesql N' {1} '
end", spName, script);
			return template;
		}

		/// <summary>
		/// Returns the CLR type to SQL Server column type mapping.
		/// Entries containing <c>{0}</c> are formatted with the field length.
		/// </summary>
		/// <returns>A new mapping dictionary on every call.</returns>
		public override Dictionary<Type, string> FieldMaping()
		{
			Dictionary<Type, string> dic = new Dictionary<Type, string>();
			dic.Add(typeof(System.String), "nvarchar({0})");
			dic.Add(typeof(System.Decimal), "decimal(18, 4)");
			dic.Add(typeof(System.Double), "float");
			dic.Add(typeof(System.Single), "real");
			dic.Add(typeof(System.Boolean), "bit");
			dic.Add(typeof(System.Int32), "int");
			dic.Add(typeof(System.Int16), "SMALLINT");
			dic.Add(typeof(System.Enum), "int");
			dic.Add(typeof(System.Byte), "tinyint");
			dic.Add(typeof(System.DateTime), "datetime2");
			// ushort ranges 0..65535, which does not fit SMALLINT (-32768..32767).
			dic.Add(typeof(System.UInt16), "int");
			dic.Add(typeof(System.Int64), "bigint");
			dic.Add(typeof(System.Object), "nvarchar(30)");
			dic.Add(typeof(System.Byte[]), "varbinary({0})");
			dic.Add(typeof(System.Guid), "uniqueidentifier");
			return dic;
		}

		/// <summary>
		/// Computes the column type (including IDENTITY for auto-increment keys) and the
		/// default-value expression for a field.
		/// </summary>
		/// <param name="info">
		/// Field metadata. Note: when <c>ValueNeedConvert</c> is set, <c>info.Length</c> is overwritten with 8000.
		/// </param>
		/// <param name="defaultValue">
		/// The field's own default, or "0" for non-key int and "getdate()" for DateTime when none is set.
		/// </param>
		/// <returns>The SQL column type text; an explicit <c>info.ColumnType</c> always wins.</returns>
		public override string GetColumnType(Attribute.FieldInnerAttribute info, out string defaultValue)
		{
			Type propertyType = info.PropertyType;
			defaultValue = info.DefaultValue;
			if (info.ValueNeedConvert)
			{
				// Converted values are persisted as text.
				propertyType = typeof(string);
				info.Length = 8000;
			}
			if (string.IsNullOrEmpty(defaultValue))
			{
				// Default for non-key int columns.
				if (!info.IsPrimaryKey && propertyType == typeof(System.Int32))
				{
					defaultValue = "0";
				}
				// Default for datetime columns.
				if (propertyType == typeof(System.DateTime))
				{
					defaultValue = "getdate()";
				}
			}
			string columnType = GetDBColumnType(propertyType);
			// Strings longer than 3000 become ntext.
			if (propertyType == typeof(System.String) && info.Length > 3000)
			{
				columnType = "ntext";
			}
			if (info.Length > 0)
			{
				columnType = string.Format(columnType, info.Length);
			}
			if (info.IsPrimaryKey)
			{
				if (info.KeepIdentity == true)
				{
					columnType = columnType + " ";
				}
				// Only integral columns can be IDENTITY. Check the column's actual (possibly converted) type,
				// otherwise a converted key would yield e.g. "ntext IDENTITY(1,1)".
				else if (propertyType == typeof(int) || propertyType == typeof(long))
				{
					columnType = columnType + " IDENTITY(1,1) ";
				}
			}
			if (!string.IsNullOrEmpty(info.ColumnType))
			{
				columnType = info.ColumnType;
			}
			return columnType;
		}

		/// <summary>
		/// Builds an <c>alter table ... add</c> script for a single field, with optional default and NOT NULL.
		/// </summary>
		/// <param name="dbContext">Context used to resolve the physical table name.</param>
		/// <param name="field">Field to add.</param>
		/// <returns>The ALTER TABLE script.</returns>
		public override string GetCreateColumnScript(DbContextInner dbContext, Attribute.FieldInnerAttribute field)
		{
			var table = TypeCache.GetTable(field.ModelType);
			var tableName = TypeCache.GetTableName(table.TableName, dbContext);
			var columnType = GetColumnType(field, out var defaultValue);

			string str = string.Format("alter table [{0}] add [{1}] {2}", tableName, field.MapingName, columnType);
			if (!string.IsNullOrEmpty(defaultValue))
			{
				str += string.Format(" default({0})", defaultValue);
			}
			if (field.NotNull)
			{
				str += " not null";
			}
			return str;
		}

		/// <summary>
		/// Builds a <c>create [unique] index</c> statement. <paramref name="owner"/> is not used by this dialect.
		/// </summary>
		/// <param name="owner">Unused.</param>
		/// <param name="table">Table the index belongs to; its name is bracket-quoted.</param>
		/// <param name="unique">Whether to create a unique index.</param>
		/// <param name="indexName">Index name, used verbatim.</param>
		/// <param name="columns">Column names, used verbatim and comma-joined.</param>
		/// <returns>The CREATE INDEX statement.</returns>
		public override string GetCreateIndexScript(string owner, TableInnerAttribute table, bool unique, string indexName, params string[] columns)
		{
			var tableName = KeyWordFormat(table.TableName);
			var index = unique ? "unique index" : "index";
			return $"create {index} {indexName} on {tableName} ({string.Join(",", columns)});";
		}

		/// <summary>
		/// Creates a table, including its clustered primary key, and then adds one
		/// DEFAULT constraint per field that has a default value.
		/// </summary>
		/// <param name="dbContext">Context whose DBHelper executes the statements.</param>
		/// <param name="fields">Fields that make up the table.</param>
		/// <param name="tableName">Physical table name.</param>
		public override void CreateTable(DbContextInner dbContext, List<Attribute.FieldInnerAttribute> fields, string tableName)
		{
			var defaultValues = new List<string>();
			string script = string.Format("create table [{0}] (\r\n", tableName);
			var columnDefinitions = new List<string>();
			string primaryKey = "";
			foreach (Attribute.FieldInnerAttribute item in fields)
			{
				if (item.IsPrimaryKey)
				{
					primaryKey = item.MapingName;
				}
				var columnType = GetColumnType(item, out var defaultValue);
				string nullStr = item.NotNull ? "NOT NULL" : "";
				columnDefinitions.Add(string.Format("[{0}] {1} {2} ", item.MapingName, columnType, nullStr));
				// DEFAULT constraints are added after the table exists.
				if (!string.IsNullOrEmpty(defaultValue))
				{
					string v = string.Format("ALTER TABLE [dbo].[{0}] ADD  CONSTRAINT [DF_{0}_{1}]  DEFAULT ({2}) FOR [{1}]", tableName, item.MapingName, defaultValue);
					defaultValues.Add(v);
				}
			}
			script += string.Join(",\r\n", columnDefinitions);
			if (!string.IsNullOrEmpty(primaryKey))
			{
				script += string.Format(@" CONSTRAINT [PK_{0}] PRIMARY KEY CLUSTERED 
(
	[{1}] ASC
)WITH (PAD_INDEX  = OFF, STATISTICS_NORECOMPUTE  = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS  = ON, ALLOW_PAGE_LOCKS  = ON) ON [PRIMARY]
", tableName, primaryKey);
			}
			script += ") ON [PRIMARY]";
			var helper = dbContext.DBHelper;
			helper.Execute(script);
			foreach (string s in defaultValues)
			{
				if (!string.IsNullOrEmpty(s))
				{
					helper.Execute(s);
				}
			}
		}
		#endregion

		/// <summary>
		/// The database type handled by this adapter.
		/// </summary>
		public override DBType DBType
		{
			get { return DBType.MSSQL; }
		}

		#region SQL queries

		/// <summary>
		/// Returns a query listing column names and type names of the named table.
		/// </summary>
		/// <param name="tableName">Table name; single quotes are escaped.</param>
		/// <returns>The query text.</returns>
		public override string GetTableFields(string tableName)
		{
			var escaped = tableName?.Replace("'", "''");
			string sql = $"select c.name, TYPE_NAME(c.system_type_id) as type FROM sys.objects obj JOIN sys.columns c ON c.object_id = obj.object_id WHERE obj.name='{escaped}'";
			return sql;
		}

		/// <summary>
		/// Inserts an object and returns its primary key value.
		/// </summary>
		/// <param name="dbContext">Context whose DBHelper executes the insert.</param>
		/// <param name="obj">Object to insert.</param>
		/// <returns>
		/// <c>null</c> when the type has no primary key; the object's own key when the key is kept
		/// (KeepIdentity); otherwise the scalar returned by <c>scope_identity()</c>.
		/// </returns>
		public override object InsertObject<T>(DbContextInner dbContext, T obj)
		{
			Type type = obj.GetType();
			var helper = dbContext.DBHelper;
			var table = TypeCache.GetTable(type);
			var primaryKey = table.PrimaryKey;

			var sql = GetInsertSql(dbContext, table, obj);
			if (primaryKey == null)
			{
				SqlStopWatch.Execute(helper, sql);
				return null;
			}
			if (primaryKey.KeepIdentity == true)
			{
				SqlStopWatch.Execute(helper, sql);
				return primaryKey.GetValue(obj);
			}
			sql += ";select scope_identity() ;";
			return SqlStopWatch.ExecScalar(helper, sql);
		}

		/// <summary>
		/// Upserts <paramref name="items"/> with a single MERGE statement. Values are inlined as escaped literals.
		/// </summary>
		/// <param name="dbContext">Context whose DBHelper executes the statement.</param>
		/// <param name="items">Items of one model type; nothing happens when empty.</param>
		/// <param name="option">
		/// Options; <c>SqlOut</c> receives the generated SQL. When <c>ConstraintMemberName</c> is set and the
		/// first item's numeric key is 0, the key is treated as auto-increment and left out of the insert.
		/// </param>
		/// <exception cref="Exception">The model type has no primary key.</exception>
		public override void InsertOrUpdate(DbContextInner dbContext, IList items, InsertOrUpdateOption option)
		{
			if (items.Count == 0)
				return;
			option = option ?? new InsertOrUpdateOption();
			var type = items[0].GetType();
			var table = TypeCache.GetTable(type);
			var primaryKey = table.PrimaryKey;
			if (primaryKey == null)
			{
				throw new Exception($"InsertOrUpdate {table.Type} 缺少主键");
			}
			var tableName = KeyWordFormat(TypeCache.GetTableName(table.TableName, dbContext));
			var constraintName = !string.IsNullOrEmpty(option.ConstraintMemberName)
				? option.ConstraintMemberName
				: KeyWordFormat(primaryKey.MapingName);

			// With an explicit unique constraint and an unset (0) numeric key, keep the key auto-incremented.
			var autoIncrement = false;
			var allFields = table.Fields.AsQueryable();
			if (primaryKey.PropertyType.IsNumeric() && !string.IsNullOrEmpty(option.ConstraintMemberName))
			{
				var keyValue = Convert.ToInt64(primaryKey.GetValue(items[0]));
				if (keyValue == 0)
				{
					autoIncrement = true;
					allFields = allFields.Where(b => !b.IsPrimaryKey);
				}
			}

			// SET IDENTITY_INSERT fails (error 8106) on tables without an identity column, so only toggle it
			// when the key is an IDENTITY column by the same rule GetColumnType uses to create one.
			var keyIsIdentity = primaryKey.KeepIdentity != true
				&& !primaryKey.ValueNeedConvert
				&& (primaryKey.PropertyType == typeof(int) || primaryKey.PropertyType == typeof(long));
			var toggleIdentityInsert = !autoIncrement && keyIsIdentity;

			var sb = new StringBuilder();
			if (toggleIdentityInsert)
			{
				sb.AppendLine($"SET IDENTITY_INSERT {tableName} ON;");
			}
			sb.AppendLine($"MERGE INTO {tableName} t1 ");
			sb.AppendLine("USING (");
			sb.AppendLine(getMergeSelect(dbContext, allFields, items));
			sb.AppendLine($") t2 ON (t1.{constraintName} = t2.{constraintName})");
			if (!option.IfExistsNotUpdate)
			{
				var updateFields = table.Fields.Where(b => !b.IsPrimaryKey);
				if (option.UpdateMemberNames?.Any() == true)
				{
					updateFields = updateFields.Where(b => option.UpdateMemberNames.Contains(b.MemberName));
				}
				var assignments = updateFields
					.Select(b => $"{KeyWordFormat(b.MapingName)}=t2.{KeyWordFormat(b.MapingName)}")
					.ToList();
				// An empty SET list is a syntax error; without columns there is nothing to update.
				if (assignments.Count > 0)
				{
					sb.AppendLine("WHEN MATCHED THEN");
					sb.AppendLine($"update set {string.Join(",", assignments)}");
				}
			}
			sb.AppendLine("WHEN NOT MATCHED THEN ");
			string columnList(string prefix)
			{
				return string.Join(",", allFields.Select(b => $"{prefix}{KeyWordFormat(b.MapingName)}"));
			}
			sb.AppendLine($"insert ({columnList("")})  values ({columnList("t2.")});");
			if (toggleIdentityInsert)
			{
				sb.AppendLine($"SET IDENTITY_INSERT {tableName} OFF;");
			}
			option.SqlOut = sb.ToString();
			dbContext.DBHelper.Execute(option.SqlOut);
		}

		/// <summary>
		/// Bulk-inserts <paramref name="details"/> through a DataTable.
		/// </summary>
		/// <param name="dbContext">Context; its DBHelper must be a <see cref="SqlHelper"/>.</param>
		/// <param name="details">Rows to insert; nothing happens when empty.</param>
		/// <param name="keepIdentity">Whether explicit key values are inserted.</param>
		/// <exception cref="InvalidOperationException">The context's DBHelper is not a SqlHelper.</exception>
		public override void BatchInsert(DbContextInner dbContext, System.Collections.IList details, bool keepIdentity = false)
		{
			if (details.Count == 0)
				return;
			var helper = dbContext.DBHelper as SqlHelper;
			if (helper == null)
			{
				throw new InvalidOperationException("MSSQL BatchInsert requires a SqlHelper, but the context provides " + dbContext.DBHelper?.GetType());
			}
			var tempTable = GetBatchInsertTable(dbContext, details, keepIdentity);
			helper.InsertFromDataTable(tempTable, keepIdentity);
		}

		/// <summary>
		/// Builds the MERGE source: one <c>select</c> per item joined by UNION ALL. Column aliases appear on the first row only.
		/// </summary>
		/// <param name="dbContext">Unused.</param>
		/// <param name="allFields">Fields to project.</param>
		/// <param name="items">Items providing the values.</param>
		/// <returns>The SELECT/UNION ALL text.</returns>
		string getMergeSelect(DbContextInner dbContext, IQueryable<FieldInnerAttribute> allFields, IList items)
		{
			var invariant = System.Globalization.CultureInfo.InvariantCulture;

			// Renders a value as a SQL literal. Text is N-prefixed and quote-escaped so apostrophes and
			// non-ASCII characters survive; numbers and dates use culture-invariant formats.
			string toLiteral(FieldInnerAttribute field, object value)
			{
				if (value is bool || value is Enum)
				{
					value = Convert.ToInt32(value);
				}
				if (field.PropertyType.IsNumeric())
				{
					return value == null ? "NULL" : Convert.ToString(value, invariant);
				}
				if (value is byte[] bytes)
				{
					return "0x" + BitConverter.ToString(bytes).Replace("-", "");
				}
				// ISO 8601 with 3 fractional digits converts to both datetime and datetime2 regardless of session language.
				var text = value is DateTime date
					? date.ToString("yyyy-MM-dd'T'HH:mm:ss.fff", invariant)
					: (Convert.ToString(value, invariant) ?? string.Empty);
				return "N'" + text.Replace("'", "''") + "'";
			}

			var fieldList = allFields.ToList();
			var sb = new StringBuilder();
			for (var i = 0; i < items.Count; i++)
			{
				var item = items[i];
				var values = new List<string>(fieldList.Count);
				foreach (var f in fieldList)
				{
					var literal = toLiteral(f, f.GetValue(item));
					if (i == 0)
					{
						literal += $" as {KeyWordFormat(f.MapingName)}";
					}
					values.Add(literal);
				}
				if (i > 0)
				{
					sb.AppendLine("UNION ALL");
				}
				sb.AppendLine($"select {string.Join(",", values)}");
			}
			return sb.ToString();
		}

		/// <summary>
		/// Returns the <c>with (nolock)</c> table hint, or an empty string when <paramref name="v"/> is false.
		/// </summary>
		/// <param name="v">Whether the hint is wanted.</param>
		/// <returns>The hint text.</returns>
		public override string GetWithNolockFormat(bool v)
		{
			if (!v)
			{
				return "";
			}
			return " with (nolock)";
		}

		/// <summary>
		/// Appends <c>select [top n] fields {query} {sort}</c> to <paramref name="sb"/>.
		/// </summary>
		/// <param name="sb">Target builder.</param>
		/// <param name="fields">Column list, e.g. "id,name".</param>
		/// <param name="query">Callback appending the FROM/WHERE part, e.g. "from table where 1=1".</param>
		/// <param name="sort">Optional ORDER BY text, appended as-is.</param>
		/// <param name="top">Row limit; 0 or less means no limit.</param>
		public override void GetSelectFull(StringBuilder sb, string fields, Action<StringBuilder> query, string sort, int top)
		{
			sb.Append("select ");
			if (top > 0)
			{
				sb.AppendFormat("top {0} ", top);
			}
			sb.Append(fields);
			query(sb);
			if (!string.IsNullOrEmpty(sort))
			{
				sb.Append(sort);
			}
		}
		#endregion

		#region System queries

		/// <summary>
		/// Query listing user tables as (lower-cased name, name). <paramref name="db"/> is unused.
		/// </summary>
		public override string GetAllTablesSql(string db)
		{
			return "select Lower(name),name from sysobjects where  type='u'";
		}

		/// <summary>
		/// Query listing stored procedures as (name, id). <paramref name="db"/> is unused.
		/// </summary>
		public override string GetAllSPSql(string db)
		{
			return "select name,id from sysobjects where  type='P'";
		}
		#endregion

		#region Templates

		/// <summary>
		/// Formats one stored-procedure parameter declaration, including its trailing comma.
		/// </summary>
		/// <param name="name">Parameter name without '@'.</param>
		/// <param name="type">SQL type.</param>
		/// <param name="output">Whether the parameter is an OUTPUT parameter.</param>
		/// <returns>E.g. "@id int," or "@count int output,".</returns>
		public override string SpParameFormat(string name, string type, bool output)
		{
			string str = output ? "@{0} {1} output," : "@{0} {1},";
			return string.Format(str, name, type);
		}

		static ConcurrentDictionary<string, string> KeyWordFormatCache = new ConcurrentDictionary<string, string>();

		/// <summary>
		/// Quotes an identifier with square brackets, caching the result per distinct value.
		/// </summary>
		/// <param name="value">Identifier; must not be null (ConcurrentDictionary rejects null keys).</param>
		/// <returns>"[value]".</returns>
		public override string KeyWordFormat(string value)
		{
			// Use the factory overload so the bracketed string is only built on a cache miss.
			return KeyWordFormatCache.GetOrAdd(value, v => $"[{v}]");
		}

		/// <summary>
		/// Stored-procedure template for paging a GROUP BY query.
		/// Placeholders: {name}, {parame}, {sql}, {fields}, {rowOver}.
		/// </summary>
		public override string TemplateGroupPage
		{
			get
			{
				string str = @"
--group分页
CREATE PROCEDURE [dbo].{name}
{parame}
--参数传入 @pageSize,@pageIndex
AS
set  nocount  on
declare @start nvarchar(20) 
declare @end nvarchar(20)
declare @pageCount INT

begin

    --获取记录数
	  select @count=count(0) from (select count(*) as a  {sql}) t
    if @count = 0
    return
    if @count = 0
        set @count = 1

    --取得分页总数
    set @pageCount=(@count+@pageSize-1)/@pageSize

    /**当前页大于总页数 取最后一页**/
    --if @pageIndex>@pageCount
        --set @pageIndex=@pageCount

	--计算开始结束的行号
	set @start = @pageSize*(@pageIndex-1)+1
	set @end = @start+@pageSize-1 
	select * FROM (select {fields},ROW_NUMBER() OVER ( Order by {rowOver} ) AS RowNumber {sql}) T WHERE T.RowNumber BETWEEN @start AND @end 
end
";
				return str;
			}
		}

		/// <summary>
		/// Stored-procedure template for paging a plain table query.
		/// Placeholders: {name}, {parame}, {sql}, {fields}, {rowOver}.
		/// </summary>
		public override string TemplatePage
		{
			get
			{
				string str = @"
--表分页
CREATE PROCEDURE [dbo].{name}
{parame}
--参数传入 @pageSize,@pageIndex
AS
set  nocount  on
declare @start nvarchar(20) 
declare @end nvarchar(20)
declare @pageCount INT

begin

    --获取记录数
	  select @count=count(0) {sql}
    if @count = 0
    return
    if @count = 0
        set @count = 1

    --取得分页总数
    set @pageCount=(@count+@pageSize-1)/@pageSize

    /**当前页大于总页数 取最后一页**/
    --if @pageIndex>@pageCount
        --set @pageIndex=@pageCount

	--计算开始结束的行号
	set @start = @pageSize*(@pageIndex-1)+1
	set @end = @start+@pageSize-1 
	select * FROM (select {fields},ROW_NUMBER() OVER ( Order by {rowOver} ) AS RowNumber {sql}) T WHERE T.RowNumber BETWEEN @start AND @end order by RowNumber
end

";
				return str;
			}
		}

		/// <summary>
		/// Generic stored-procedure template. Placeholders: {name}, {parame}, {sql}.
		/// </summary>
		public override string TemplateSp
		{
			get
			{
				string str = @"
CREATE PROCEDURE [dbo].{name}
{parame}
AS
set  nocount  on
	{sql}
";
				return str;
			}
		}

		/// <summary>
		/// No dialect rewriting is needed for SQL Server; returns <paramref name="sql"/> unchanged.
		/// </summary>
		public override string SqlFormat(string sql)
		{
			return sql;
		}
		#endregion

		#region Function formatting

		/// <summary>
		/// SUBSTRING with a 0-based <paramref name="index"/> (converted to SQL's 1-based start).
		/// </summary>
		public override string SubstringFormat(string field, int index, int length)
		{
			return string.Format(" SUBSTRING({0},{1},{2})", field, index + 1, length);
		}

		/// <summary><c>field LIKE par</c>.</summary>
		public override string StringLikeFormat(string field, string parName)
		{
			return string.Format("{0} LIKE {1}", field, parName);
		}

		/// <summary><c>field NOT LIKE par</c>.</summary>
		public override string StringNotLikeFormat(string field, string parName)
		{
			return string.Format("{0} NOT LIKE {1}", field, parName);
		}

		/// <summary>Containment test via <c>CHARINDEX(par, field) &gt; 0</c>.</summary>
		public override string StringContainsFormat(string field, string parName)
		{
			return string.Format("CHARINDEX({1},{0})>0", field, parName);
		}

		/// <summary>Negated containment test via <c>CHARINDEX(par, field) &lt;= 0</c>.</summary>
		public override string StringNotContainsFormat(string field, string parName)
		{
			return string.Format("CHARINDEX({1},{0})<=0", field, parName);
		}

		/// <summary><c>field between p1 and p2</c>.</summary>
		public override string BetweenFormat(string field, string parName, string parName2)
		{
			return string.Format("{0} between {1} and {2}", field, parName, parName2);
		}

		/// <summary><c>DateDiff(format, field, par)</c>.</summary>
		public override string DateDiffFormat(string field, string format, string parName)
		{
			return string.Format("DateDiff({0},{1},{2})", format, field, parName);
		}

		/// <summary><c>field IN (par)</c>.</summary>
		public override string InFormat(string field, string parName)
		{
			return string.Format("{0} IN ({1})", field, parName);
		}

		/// <summary><c>field NOT IN (par)</c>.</summary>
		public override string NotInFormat(string field, string parName)
		{
			return string.Format("{0} NOT IN ({1})", field, parName);
		}
		#endregion

		/// <summary>
		/// ROW_NUMBER-based paging query for rows <paramref name="start"/>..<paramref name="end"/> (inclusive).
		/// <paramref name="db"/> and <paramref name="sort"/> are unused; ordering comes from <paramref name="rowOver"/>.
		/// </summary>
		public override string PageSqlFormat(DBHelper db, string fields, string rowOver, string condition, int start, int end, string sort)
		{
			string sql = "select * FROM (select {0},ROW_NUMBER() OVER ( Order by {1} ) AS RowNumber {2}) T WHERE T.RowNumber BETWEEN {3} AND {4} order by RowNumber";
			return string.Format(sql, fields, rowOver, condition, start, end);
		}

		/// <summary>
		/// Builds an UPDATE of <paramref name="t1"/> joined to <paramref name="t2"/>.
		/// </summary>
		/// <param name="t1">Target table name (bracket-quoted here).</param>
		/// <param name="t2">Source table text, used verbatim in the non-join form.</param>
		/// <param name="condition">Either a JOIN clause (contains " join ") or a WHERE predicate.</param>
		/// <param name="setValue">SET list using the t1 alias.</param>
		/// <param name="query">Unused.</param>
		/// <returns>The UPDATE statement.</returns>
		public override string GetRelationUpdateSql(string t1, string t2, string condition, string setValue, LambdaQuery.LambdaQueryBase query)
		{
			// Ordinal, case-insensitive match: culture-sensitive ToLower() misbehaves e.g. under tr-TR ("JOIN" -> "joın").
			if (condition.IndexOf(" join ", StringComparison.OrdinalIgnoreCase) >= 0)
			{
				return string.Format("update t1 set {0} from {1} t1 {2}", setValue, KeyWordFormat(t1), condition);
			}
			// Non-join form: the target is addressed by its real name, so strip/replace the t1 alias.
			setValue = setValue.Replace("t1.", "");
			condition = condition.Replace("t1.", $"{KeyWordFormat(t1)}.");
			return $"update {KeyWordFormat(t1)} set {setValue} from {t2} t2 where {condition}";
		}

		/// <summary>
		/// Builds <c>delete t1 from [t1] t1 {condition}</c>. <paramref name="t2"/> and <paramref name="query"/> are unused.
		/// </summary>
		public override string GetRelationDeleteSql(string t1, string t2, string condition, LambdaQuery.LambdaQueryBase query)
		{
			return string.Format("delete t1 from {0} t1 {1}", KeyWordFormat(t1), condition);
		}

		/// <summary>
		/// Builds a <c>field + value</c> concatenation / addition expression.
		/// </summary>
		/// <param name="field">Field expression.</param>
		/// <param name="value">
		/// For strings, literal text (quote-escaped and N-prefixed); otherwise inserted verbatim.
		/// </param>
		/// <param name="type">Value type; selects string or numeric form.</param>
		/// <returns>The SQL expression.</returns>
		public override string GetFieldConcat(string field, object value, Type type)
		{
			if (type == typeof(string))
			{
				var text = Convert.ToString(value)?.Replace("'", "''");
				return string.Format("{0}+N'{1}'", field, text);
			}
			return string.Format("{0}+{1}", field, value);
		}

		/// <summary>
		/// Builds <c>CAST(field as type)</c> using <see cref="FieldMaping"/>; length placeholders become 100.
		/// Enum types map like <see cref="Enum"/>; <see cref="Nullable{T}"/> maps like its underlying type.
		/// </summary>
		/// <param name="field">Field expression.</param>
		/// <param name="fieldType">Target CLR type.</param>
		/// <returns>The CAST expression.</returns>
		/// <exception cref="Exception">No mapping exists for <paramref name="fieldType"/>.</exception>
		public override string CastField(string field, Type fieldType)
		{
			var dic = FieldMaping();
			var lookupType = Nullable.GetUnderlyingType(fieldType) ?? fieldType;
			if (lookupType.IsEnum)
			{
				lookupType = typeof(Enum);
			}
			if (!dic.TryGetValue(lookupType, out var type))
			{
				throw new Exception(string.Format("没找到对应类型的转换{0} 在字段{1}", fieldType, field));
			}
			type = string.Format(type, 100);
			return string.Format("CAST({0} as {1})", field, type);
		}

		/// <summary>
		/// Parameter name: "@" + name + index.
		/// </summary>
		public override string GetParamName(string name, object index)
		{
			return string.Format("@{0}{1}", name, index);
		}

		/// <summary>
		/// <c>CONVERT(varchar(100), field, style)</c>, where <paramref name="format"/> is a CONVERT style code.
		/// </summary>
		public override string DateTimeFormat(string field, string format)
		{
			return string.Format("CONVERT(varchar(100), {0}, {1})", field, format);
		}

		/// <summary>
		/// Text before the first occurrence of <paramref name="parName"/> (used as literal separator text).
		/// </summary>
		/// <remarks>
		/// NOTE(review): when the separator is absent, CHARINDEX returns 0 and SUBSTRING receives length -1,
		/// which SQL Server rejects. Confirm whether callers guarantee the separator is present.
		/// </remarks>
		public override string GetSplitFirst(string field, string parName)
		{
			return $"substring({field},1,charindex('{parName}',{field})-1)";
		}

		/// <summary>
		/// Reads current column lengths of a table from the catalog.
		/// </summary>
		/// <param name="dbContext">Context whose DBHelper runs the query.</param>
		/// <param name="tableName">Table name; single quotes are escaped.</param>
		/// <returns>
		/// Column name to length for columns whose max_length exceeds 20 bytes (sysname excluded).
		/// nvarchar/nchar byte lengths are halved to character counts.
		/// </returns>
		public override Dictionary<string, long> GetFieldLength(DbContextInner dbContext, string tableName)
		{
			var sql = @"select 
TableName = OBJECT_NAME(c.object_id), 
ColumnsName = c.name, 
ColumnType=t.name, 
Length=c.max_length 
FROM 
sys.columns c 

left outer join 
systypes t 
on c.system_type_id=t.xtype 
WHERE 
OBJECTPROPERTY(c.object_id, 'IsMsShipped')=0 
AND OBJECT_NAME(c.object_id) ='{0}'
and c.max_length>20 and t.name!='sysname'";
			sql = string.Format(sql, (tableName ?? string.Empty).Replace("'", "''"));
			var helper = dbContext.DBHelper;
			var dt = helper.ExecDataTable(sql);
			var dic = new Dictionary<string, long>();
			foreach (System.Data.DataRow dr in dt.Rows)
			{
				var columnType = dr["ColumnType"].ToString();
				var length = Convert.ToInt64(dr["Length"]);
				// Unicode types store 2 bytes per character.
				if (columnType == "nvarchar" || columnType == "nchar")
				{
					length = length / 2;
				}
				// Indexer instead of Add: same-named tables in other schemas (or type aliases) can yield duplicate rows.
				dic[dr["ColumnsName"].ToString()] = length;
			}
			return dic;
		}

		/// <summary>
		/// Sets the MS_Description extended property of a table (update, falling back to add).
		/// </summary>
		/// <param name="db">Helper used to run the system procedures.</param>
		/// <param name="table">Table name under schema dbo.</param>
		/// <param name="comment">Description text.</param>
		public override void UpdateTableComment(DBHelper db, string table, string comment)
		{
			RunMsDescriptionProperty(db, table, null, null, comment);
		}

		/// <summary>
		/// Sets the MS_Description extended property of a column (update, falling back to add).
		/// </summary>
		/// <param name="db">Helper used to run the system procedures.</param>
		/// <param name="table">Table name under schema dbo.</param>
		/// <param name="field">Column whose mapping name is targeted.</param>
		/// <param name="comment">Description text.</param>
		public override void UpdateFieldComment(DBHelper db, string table, FieldInnerAttribute field, string comment)
		{
			RunMsDescriptionProperty(db, table, "column", field.MapingName, comment);
		}

		/// <summary>
		/// Shared implementation of the comment updates: binds the extended-property parameters, tries
		/// sp_updateextendedproperty, falls back to sp_addextendedproperty, and always clears the parameters.
		/// </summary>
		static void RunMsDescriptionProperty(DBHelper db, string table, string level2type, string level2name, string comment)
		{
			try
			{
				db.AddParam("name", "MS_Description");
				db.AddParam("value", comment);
				db.AddParam("level0type", "user");
				db.AddParam("level0name", "dbo");
				db.AddParam("level1type", "table");
				db.AddParam("level1name", table);
				db.AddParam("level2type", level2type);
				db.AddParam("level2name", level2name);
				try
				{
					db.Run("sp_updateextendedproperty");
				}
				catch
				{
					// Update fails when the property does not exist yet; create it instead.
					db.Run("sp_addextendedproperty");
				}
			}
			finally
			{
				// Always clear, so a failure here cannot leak parameters into the next command.
				db.ClearParams();
			}
		}
	}
}
